import tensorflow as tf

from keras.layers import *


class Encoder(tf.keras.Model):
    """Encoder that maps an input batch to a (mean, sd) pair per latent dim.

    Args:
        latent_dim: width of the two Dense heads (mean and sd).
        encoder: optional feature-extractor model used instead of the default
            convolution stack; its output feeds both Dense heads.
        simple: only used when ``encoder`` is None. True builds two stride-2
            ReLU convolutions (32, 64 filters); False builds three
            (32, 64, 64 filters). Both end with a Flatten.

    ``call`` returns ``tf.stack([mean, sd], axis=-1)``. With the default
    (flattened) backbone this is (batch, latent_dim, 2). The default backbone
    uses Conv2D, so it expects rank-4 image batches. The image size is not
    fixed here; TODO confirm it against the data pipeline.

    Setting ``self.eval = True`` makes ``call`` skip the sd head and return
    an all-zero sd channel.
    """

    def __init__(self, latent_dim=2, encoder=None, simple=True):
        super().__init__()
        self.latent_dim = latent_dim
        # Toggled externally; when True, `call` reports sd == 0.
        self.eval = False

        if encoder is not None:
            self.encoder = encoder
        else:
            # The non-simple variant inserts one extra 64-filter stage.
            conv_widths = (32, 64) if simple else (32, 64, 64)
            stages = [
                Conv2D(filters=width, kernel_size=3, strides=(2, 2), activation='relu')
                for width in conv_widths
            ]
            stages.append(Flatten())
            self.encoder = tf.keras.Sequential(stages)
        self.mean = Dense(latent_dim)
        self.sd = Dense(latent_dim)

    def call(self, input):
        features = self.encoder(input)
        mean = self.mean(features)
        # The sd head output is treated as a log-variance:
        # exp(0.5 * v) == sqrt(exp(v)).
        sd = tf.zeros_like(mean) if self.eval else tf.exp(0.5 * self.sd(features))
        return tf.stack([mean, sd], axis=-1)


class Decoder(tf.keras.Model):
    """Transposed-convolution decoder from a latent vector to a 1-channel image.

    simple=True: the latent is projected to a 7x7x32 map and upsampled twice
    (stride 2, 'same' padding), producing a 28x28x1 output.
    simple=False: it is projected to 8x8x64 and upsampled three times,
    producing a 64x64x1 output.

    The last layer is a stride-1 tanh projection, so outputs are bounded to
    [-1, 1].
    """

    def __init__(self, latent_dim=2, simple=True):
        super().__init__()
        self.latent_dim = latent_dim

        side, depth = (7, 32) if simple else (8, 64)
        # Stride-2 ReLU upsampling stages; the non-simple variant adds one more.
        upsample_widths = (64, 32) if simple else (64, 64, 32)

        stack = [
            Dense(side * side * depth, activation=tf.nn.relu),
            Reshape(target_shape=(side, side, depth)),
        ]
        stack.extend(
            Conv2DTranspose(filters=width, kernel_size=3, strides=2, padding='same', activation='relu')
            for width in upsample_widths
        )
        # Final projection to a single tanh-activated channel at full resolution.
        stack.append(Conv2DTranspose(filters=1, kernel_size=3, strides=1, padding='same', activation='tanh'))

        self.decoder = tf.keras.Sequential(stack)

    def call(self, input):
        # The input is a batch of sample lists from the DPL program (presumably
        # (batch, n_samples, latent_dim) — confirm against caller); only the
        # first sample of each list is decoded.
        first_sample = tf.reshape(input[:, 0, :], [-1, self.latent_dim])
        return self.decoder(first_sample)


class VAEClassifier(tf.keras.Model):
    """Small MLP classifier over one latent sample per batch element.

    The network is Dense(32)-ReLU-Dense(24)-ReLU-Dense(N)-Softmax. It returns
    N softmax-normalised class scores per batch element.

    Args:
        N: number of output classes.
        latent_dim: width of each latent sample fed to the MLP.
    """

    def __init__(self, N=10, latent_dim=2):
        super().__init__()
        self.N = N
        self.latent_dim = latent_dim
        self.model = tf.keras.Sequential([
            Dense(32), ReLU(),
            Dense(24), ReLU(),
            Dense(N), Softmax(),
        ])

    def call(self, input):
        # Only the first sample along axis 1 is classified. An averaging
        # variant (tf.reduce_mean(input, axis=1)) was previously left
        # commented out here.
        first_sample = tf.reshape(input[:, 0, :], [-1, self.latent_dim])
        return self.model(first_sample)

